#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

import os
import io
import argparse
import datetime

# import numpy as np
import sqlalchemy
from sqlalchemy import text
import pandas as pd
from sshtunnel import SSHTunnelForwarder


def print_to_string(*args, **kwargs):
    """Return the exact text ``print(*args, **kwargs)`` would write to stdout.

    Any keyword accepted by ``print`` (``sep``, ``end``, ...) is honoured,
    except ``file``, which is supplied internally.
    """
    with io.StringIO() as buffer:
        print(*args, file=buffer, **kwargs)
        return buffer.getvalue()


def parse_args():
    """Parse the command line and resolve the list of log files to process.

    The single positional argument is either one log file or a directory.
    For a directory, every regular file whose name contains "log" is
    selected. Subdirectories are skipped because they cannot be opened as
    logs. Names are sorted so the processing order is reproducible
    (``os.listdir`` order is arbitrary).

    Returns:
        argparse.Namespace with ``filename`` (the raw argument) and
        ``files`` (list of paths to process).
    """
    parser = argparse.ArgumentParser(description="Parse results from tf benchmark runs")
    parser.add_argument(
        "filename", type=str, help="Log file to parse or directory containing log files"
    )
    args = parser.parse_args()
    if os.path.isdir(args.filename):
        args.files = [
            os.path.join(args.filename, name)
            for name in sorted(os.listdir(args.filename))
            if "log" in name and os.path.isfile(os.path.join(args.filename, name))
        ]
    else:
        args.files = [args.filename]
    return args


def get_log_params(logfile):
    """Extract the run metadata recorded in a performance log.

    For each marker, the value is the third whitespace-separated token of the
    last line containing that marker (e.g. ``Node name: node01`` gives
    ``node01``). Marker lines with fewer than three tokens are ignored
    instead of raising IndexError. The ROCm version is cut out of the
    ``InstalledDir`` path, between ``/opt/rocm-`` and ``/llvm/bin``.

    Args:
        logfile: path of the log file to scan.

    Returns:
        Tuple ``(branch_name, node_id, gpu_arch, compute_units, rocm_vers,
        hip_vers, environment)``. Fields that were not found keep their
        defaults: ``" "`` for the string fields and the integer ``0`` for
        ``compute_units``. When found, ``compute_units`` is a string token.
    """
    print("logfile=", logfile)
    found = {
        "branch_name": " ",
        "node_id": " ",
        "gpu_arch": " ",
        "hip_vers": " ",
        "compute_units": 0,
        "environment": " ",
        "rocm_vers": " ",
    }
    # (marker substring, field). A line may match several markers; they are
    # applied in this order, so later entries win (as in the original chain).
    markers = (
        ("Branch name", "branch_name"),
        ("On branch", "branch_name"),
        ("Node name", "node_id"),
        ("GPU_arch", "gpu_arch"),
        ("HIP version", "hip_vers"),
        ("Compute Unit", "compute_units"),
        ("Environment type", "environment"),
    )
    with open(logfile) as f:
        for line in f:
            lst = line.split()
            for marker, key in markers:
                if marker in line and len(lst) > 2:
                    found[key] = lst[2]
            if "InstalledDir" in line and len(lst) > 1:
                path = lst[1]
                start = path.find("/opt/rocm-") + len("/opt/rocm-")
                found["rocm_vers"] = path[start : path.rfind("/llvm/bin")]
    return (
        found["branch_name"],
        found["node_id"],
        found["gpu_arch"],
        found["compute_units"],
        found["rocm_vers"],
        found["hip_vers"],
        found["environment"],
    )


def _matching_lines(logfile, marker):
    """Return the whitespace-split tokens of each line of *logfile* containing *marker*."""
    with open(logfile) as f:
        return [line.split() for line in f if marker in line]


def _token_before(tokens, key):
    """Return the token just before the last occurrence of *key* at index >= 1.

    Raises KeyError when *key* does not occur at index >= 1.
    """
    return dict(zip(tokens[1:], tokens))[key]


def parse_logfile(logfile):
    """Extract the per-test performance values from a CK performance log.

    The log category is chosen from substrings of the *logfile* path, and
    each category reads a different field:

    * ``perf_gemm`` (but not ``gemm_bilinear``): token 33 of each
      "Best Perf" line. Results are ordered by the test description
      (tokens 5..29 concatenated). Lines too short to contain token 33 are
      reported and skipped.
    * ``conv_fwd`` / ``conv_bwd``: token 1 of each "tflops:" line.
    * ``resnet50``, ``batched_gemm``, ``grouped_gemm``, ``gemm_bilinear``,
      ``reduction``: token 4 of each "Best Perf" line.
    * ``onnx_gemm``: token 33 of each "Best Perf" line.
    * ``splitK_gemm`` / ``mixed_gemm``: token 36 of each "Best Perf" line.
    * ``perf_fmha``, ``perf_tile_gemm_basic``, ``perf_tile_gemm_mem_pipeline``:
      the token preceding "TFlops," on each line containing "TFlops".

    Args:
        logfile: path of the log file. The path itself selects the category.

    Returns:
        List of the extracted values as strings. The list is empty when the
        path matches no category (the file is then not opened).
    """
    res = []
    if "perf_gemm" in logfile and "gemm_bilinear" not in logfile:
        keyed = []
        for lst in _matching_lines(logfile, "Best Perf"):
            if len(lst) > 33:  # the tflops token (index 33) is present
                keyed.append(("".join(lst[5:30]), lst[33]))
                if len(lst) < 37:  # kernel name is missing
                    print("warning: incomplete line:", lst)
            else:  # even the tflops are not available
                print("Error in ckProfiler output!")
                print("warning: incomplete line=", lst)
        # sort results by test description so columns line up between runs
        res = [tflops for _, tflops in sorted(keyed)]
    # parse conv_fwd and conv_bwd performance tests:
    elif "conv_fwd" in logfile or "conv_bwd" in logfile:
        res = [lst[1] for lst in _matching_lines(logfile, "tflops:")]
    # parse all other performance tests:
    elif any(
        key in logfile
        for key in ("resnet50", "batched_gemm", "grouped_gemm", "gemm_bilinear", "reduction")
    ):
        res = [lst[4] for lst in _matching_lines(logfile, "Best Perf")]
    elif "onnx_gemm" in logfile:
        res = [lst[33] for lst in _matching_lines(logfile, "Best Perf")]
    elif "splitK_gemm" in logfile or "mixed_gemm" in logfile:
        res = [lst[36] for lst in _matching_lines(logfile, "Best Perf")]
    elif (
        "perf_fmha" in logfile
        or "perf_tile_gemm_basic" in logfile
        or "perf_tile_gemm_mem_pipeline" in logfile
    ):
        res = [_token_before(lst, "TFlops,") for lst in _matching_lines(logfile, "TFlops")]
    return res


def get_baseline(table, connection):
    """Fetch the most recent ``develop`` results from *table* as a DataFrame.

    Selects every row whose Datetime equals the latest Datetime recorded for
    Branch_ID 'develop'. The result is empty when the table has no such row.
    *table* is spliced into the SQL text unescaped, so it must be a trusted
    identifier.
    """
    sql = (
        "SELECT * from " + table
        + " WHERE Datetime = (SELECT MAX(Datetime) FROM " + table
        + " where Branch_ID='develop' );"
    )
    return pd.read_sql(text(sql), connection)


def store_new_test_result(
    table_name,
    test_results,
    testlist,
    branch_name,
    node_id,
    gpu_arch,
    compute_units,
    rocm_vers,
    hip_vers,
    environment,
    connection,
):
    """Append one row of run metadata plus test results to *table_name*.

    The row holds the metadata columns (Branch_ID, Node_ID, GPU_arch,
    "Compute Units", ROCM_version, HIP_version, Environment, Datetime), then
    one column per name in *testlist* holding the matching entry of
    *test_results*. Datetime is the current local time as a string.
    compute_units is stored as given; every other metadata value is passed
    through ``str``.

    Returns:
        Always 0.
    """
    header = [
        ("Branch_ID", str(branch_name)),
        ("Node_ID", str(node_id)),
        ("GPU_arch", str(gpu_arch)),
        ("Compute Units", compute_units),
        ("ROCM_version", str(rocm_vers)),
        ("HIP_version", str(hip_vers)),
        ("Environment", str(environment)),
        ("Datetime", str(datetime.datetime.now())),
    ]
    meta_frame = pd.DataFrame(
        data=[[value for _, value in header]],
        columns=[name for name, _ in header],
    )
    results_frame = pd.DataFrame(data=[test_results], columns=testlist)
    row = pd.concat([meta_frame, results_frame], axis=1)
    row.to_sql(table_name, connection, if_exists="append", index=False)
    return 0


def compare_test_to_baseline(baseline, test, testlist):
    """Compare new results against a baseline row and report regressions.

    A test regresses when its baseline value exceeds 1.01 times the new
    value, i.e. the new result is more than ~1% slower. The printed average
    ratio (new / baseline) covers only tests with a positive baseline, so
    tests without a usable baseline (0 or NULL) no longer drag it down.

    Args:
        baseline: DataFrame whose first row holds the baseline values in the
            columns named by *testlist*. An empty DataFrame means no baseline.
        test: new results (strings or numbers), in *testlist* order.
        testlist: result column names, e.g. "Test1" or "Layer1".

    Returns:
        1 if any test regressed, otherwise 0.
    """
    regression = 0
    if baseline.empty:
        print("could not find a baseline")
        return regression

    base_list = baseline[testlist].to_numpy(dtype="float")[0]
    if len(test) != len(base_list):
        print(
            "warning: got", len(test), "results but the baseline has", len(base_list), "values"
        )
    ratio_sum = 0.0
    n_valid = 0
    for name, base_val, raw in zip(testlist, base_list, test):
        new_val = float(raw)
        # success criterion: new result within 1% of the baseline
        if base_val > 1.01 * new_val:
            print(
                "test # ",
                name,
                "shows regression by {:.3f}%".format((new_val - base_val) / base_val * 100),
            )
            regression = 1
        if base_val > 0:  # skip 0/NaN baselines: no meaningful ratio
            ratio_sum += new_val / base_val
            n_valid += 1
    if regression == 0:
        print("no regressions found")
    if n_valid:
        print("average performance relative to baseline:", ratio_sum / n_valid)
    else:
        print("average performance relative to baseline: n/a (no positive baseline values)")
    return regression


# NOTE(review): disabled code kept for reference. This is a bare module-level
# string literal, so it is evaluated and discarded and post_test_params is
# never defined. It would not run as written:
# - it reads `tests`, `dtype`, `alayout`, `blayout`, `M`, `N`, `K` and
#   `StrideA/B/C`, which are not defined at module scope;
# - it uses `np`, `Integer` and `NVARCHAR`, none of which are imported (the
#   numpy import at the top of the file is commented out);
# - the DataFrame column is named 'BLayout' but the dtype map uses 'Blayout'.
"""
def post_test_params(tlist,connection):
    sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))]
    sorted_alayout = [x for _,x in sorted(zip(tests,alayout))]
    sorted_blayout = [x for _,x in sorted(zip(tests,blayout))]
    sorted_M = [x for _,x in sorted(zip(tests,M))]
    sorted_N = [x for _,x in sorted(zip(tests,N))]
    sorted_K = [x for _,x in sorted(zip(tests,K))]
    sorted_StrideA = [x for _,x in sorted(zip(tests,StrideA))]
    sorted_StrideB = [x for _,x in sorted(zip(tests,StrideB))]
    sorted_StrideC = [x for _,x in sorted(zip(tests,StrideC))]
    ck_gemm_params=[tlist,sorted_dtypes,sorted_alayout,sorted_blayout,
                sorted_M,sorted_N,sorted_K,sorted_StrideA,sorted_StrideB,
                sorted_StrideC]
    df=pd.DataFrame(np.transpose(ck_gemm_params),columns=['Test_number','Data_type',
        'Alayout','BLayout','M','N','K', 'StrideA','StrideB','StrideC'])
    print(df)

    dtypes = {
        'Test_number': Integer(),
        'Data_type': NVARCHAR(length=5),
        'Alayout': NVARCHAR(length=12),
        'Blayout': NVARCHAR(length=12),
        'M': Integer(),
        'N': Integer(),
        'K': Integer(),
        'StrideA': Integer(),
        'StrideB': Integer(),
        'StrideC': Integer()
        }
    df.to_sql("ck_gemm_test_params",connection,if_exists='replace',index=False, dtype=dtypes)
"""


# Ordered mapping from a substring of the log file path to the table that
# stores its results. The first matching entry wins. "gemm_bilinear" comes
# before "perf_gemm" so bilinear logs are never stored as plain GEMM results.
# The third element, when not None, is a fixed number of "Layer<i>" columns
# (ResNet-50 runs); otherwise one "Test<i>" column is created per result.
_RESULT_TABLES = (
    ("gemm_bilinear", "ck_gemm_bilinear_tflops", None),
    ("perf_gemm", "ck_gemm_tflops", None),
    ("batched_gemm", "ck_batched_gemm_tflops", None),
    ("grouped_gemm", "ck_grouped_gemm_tflops", None),
    ("perf_conv_fwd", "ck_conv_fwd_tflops", None),
    ("perf_conv_bwd_data", "ck_conv_bwd_data_tflops", None),
    ("grouped_conv_fwd", "ck_grouped_conv_fwd_tflops", None),
    ("grouped_conv_bwd_data", "ck_grouped_conv_bwd_data_tflops", None),
    ("grouped_conv_bwd_weight", "ck_grouped_conv_bwd_weight_tflops", None),
    ("reduction", "ck_reduction_GBps", None),
    ("resnet50_N4", "ck_resnet50_N4_tflops", 49),
    ("resnet50_N256", "ck_resnet50_N256_tflops", 49),
    ("onnx_gemm", "ck_onnx_gemm_tflops", None),
    ("splitK_gemm", "ck_splitK_gemm_tflops", None),
    ("mixed_gemm", "ck_mixed_gemm_tflops", None),
    ("fmha_fwd", "ck_fmha_fwd_tflops", None),
    ("fmha_bwd", "ck_fmha_bwd_tflops", None),
    ("gemm_basic_fp16", "ck_tile_gemm_basic_fp16_tflops", None),
    ("gemm_mem_pipeline_fp16", "ck_tile_gemm_mem_pipeline_fp16_tflops", None),
    ("gemm_basic_bf16", "ck_tile_gemm_basic_bf16_tflops", None),
    ("gemm_mem_pipeline_bf16", "ck_tile_gemm_mem_pipeline_bf16_tflops", None),
    ("gemm_basic_fp8", "ck_tile_gemm_basic_fp8_tflops", None),
    ("gemm_mem_pipeline_fp8", "ck_tile_gemm_mem_pipeline_fp8_tflops", None),
    ("gemm_basic_bf8", "ck_tile_gemm_basic_bf8_tflops", None),
    ("gemm_mem_pipeline_bf8", "ck_tile_gemm_mem_pipeline_bf8_tflops", None),
)


def _select_table(filename):
    """Return ``(table_name, fixed_layer_count)`` for *filename*, or None if unknown."""
    for pattern, table_name, layers in _RESULT_TABLES:
        if pattern in filename:
            return table_name, layers
    return None


def _column_names(n_results, layers):
    """Return the result column names: fixed Layer1..N, or Test1..n_results."""
    if layers is not None:
        return ["Layer%i" % i for i in range(1, layers + 1)]
    return ["Test%i" % i for i in range(1, n_results + 1)]


def main():
    """Parse CK performance logs, store them, and compare with the baseline.

    For every file selected by parse_args():
    1. print its run metadata;
    2. extract the results with parse_logfile();
    3. fetch the latest 'develop' baseline from the matching table;
    4. append the new results as one row;
    5. compare them against the baseline.

    Files that match no known results table are reported and skipped.

    The database is reached through an SSH tunnel configured from these
    environment variables: dbuser, dbpassword, ck_perf_db, dbsship,
    dbsshuser, dbsshport, dbsshpassword.

    Returns:
        1 if any processed file showed a regression, otherwise 0.
    """
    # Credentials may contain URL-reserved characters such as '@' or '/'.
    from urllib.parse import quote_plus

    args = parse_args()

    # Parse every log first, so nothing is sent to the DB for a bad file list.
    jobs = []
    for filename in args.files:
        params = get_log_params(filename)
        (
            branch_name,
            node_id,
            gpu_arch,
            compute_units,
            rocm_vers,
            hip_vers,
            environment,
        ) = params
        print("Branch name:", branch_name)
        print("Node name:", node_id)
        print("GPU_arch:", gpu_arch)
        print("Compute units:", compute_units)
        print("ROCM_version:", rocm_vers)
        print("HIP_version:", hip_vers)
        print("Environment:", environment)
        # parse results, get the Tflops value for "Best Perf" kernels
        results = parse_logfile(filename)
        print("Number of tests:", len(results))

        selection = _select_table(filename)
        if selection is None:
            print("warning: no results table is known for", filename, "- skipping")
            continue
        table_name, layers = selection
        jobs.append((params, results, table_name, _column_names(len(results), layers)))

    if not jobs:
        print("no performance logs to store")
        return 0

    sql_hostname = "127.0.0.1"
    sql_username = os.environ["dbuser"]
    sql_password = os.environ["dbpassword"]
    sql_main_database = os.environ["ck_perf_db"]
    sql_port = 3306
    ssh_host = os.environ["dbsship"]
    ssh_user = os.environ["dbsshuser"]
    ssh_port = int(os.environ["dbsshport"])
    ssh_pass = os.environ["dbsshpassword"]

    regression = 0
    with SSHTunnelForwarder(
        (ssh_host, ssh_port),
        ssh_username=ssh_user,
        ssh_password=ssh_pass,
        remote_bind_address=(sql_hostname, sql_port),
    ) as tunnel:
        sqlEngine = sqlalchemy.create_engine(
            "mysql+pymysql://{0}:{1}@{2}:{3}/{4}".format(
                quote_plus(sql_username),
                quote_plus(sql_password),
                sql_hostname,
                tunnel.local_bind_port,
                sql_main_database,
            )
        )
        try:
            for params, results, table_name, testlist in jobs:
                # The baseline is read before the new row is appended.
                with sqlEngine.connect() as conn:
                    tflops_base = get_baseline(table_name, conn)
                    store_new_test_result(
                        table_name, results, testlist, *params, sqlEngine
                    )
                # compare the results to the baseline if baseline exists
                if compare_test_to_baseline(tflops_base, results, testlist):
                    regression = 1
        finally:
            sqlEngine.dispose()
    return regression


if __name__ == "__main__":
    # NOTE(review): main() returns 1 when a regression is detected, but the
    # value is discarded here, so the process exit status does not report
    # regressions. Use sys.exit(main()) if CI is meant to fail on them.
    main()
